import torch
import torch.nn as nn
from torch import float16


##########################################################################################################################################
# net metric utils
##########################################################################################################################################
@torch.no_grad()
def count_params(net):
    total = 0
    for p in net.parameters():
        total+= torch.prod(torch.tensor(p.shape))
    return int(total)

@torch.no_grad()
def get_dynamic_ranks(model):
        ranks = []
        for l in model:
            if hasattr(l,'rank'):
                ranks.append(l.rank)
        return ranks

@torch.no_grad()
def accuracy_top_k(output, target, topk=(1,)):
    """Computes the non normalized version of topk accuracies"""
    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    # res = []
    res = dict()
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        # res.append(correct_k)   #correct_k.mul_(100.0 / batch_size)
        res[k] = correct_k
    return res
